#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
Test script to verify that the validation logic is working correctly.
"""

from validation_utils import (
    is_tile_config_valid,
    is_trait_combination_valid,
    validate_warp_tile_combination,
)


def test_warp_tile_validation():
    """Test warp tile combination validation"""
    print("Testing warp tile combination validation...")

    # Get GPU name
    gpu_name = "gfx90a"

    # Test cases for fp16
    test_cases = [
        # (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid)
        ([4, 64, 8], False),  # Invalid - not in supported list
        ([4, 64, 16], True),  # Valid
        ([32, 32, 8], True),  # Valid
        ([16, 16, 16], True),  # Valid
        ([32, 32, 16], True),  # Valid
        ([16, 16, 32], True),  # Valid
        ([64, 4, 16], True),  # Valid
        ([128, 128, 128], False),  # Invalid - too large
    ]

    print("\nTesting fp16 warp tile combinations:")
    for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases:
        valid, msg = validate_warp_tile_combination(
            warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name
        )
        status = "PASS" if valid == expected else "FAIL"
        print(f"  [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}")
        if not valid and msg:
            print(f"    Reason: {msg}")


def test_trait_combinations():
    """Test trait combination validation"""
    print("\n\nTesting trait combination validation...")

    test_cases = [
        # (pipeline, epilogue, scheduler, expected_valid)
        ("mem", "default", "intrawave", True),
        ("mem", "cshuffle", "intrawave", True),
        ("compv3", "default", "interwave", False),  # Invalid combination
        ("compv3", "cshuffle", "interwave", False),  # Invalid combination
        ("compv4", "default", "interwave", False),  # Invalid combination
        ("compv4", "cshuffle", "interwave", False),  # Invalid combination
        ("compv3", "default", "intrawave", True),
        ("compv4", "cshuffle", "intrawave", True),
    ]

    print("\nTesting trait combinations:")
    for pipeline, epilogue, scheduler, expected in test_cases:
        valid = is_trait_combination_valid(pipeline, epilogue, scheduler)
        status = "PASS" if valid == expected else "FAIL"
        print(f"  {pipeline}-{epilogue}-{scheduler}: {valid} - {status}")


def test_full_tile_config_validation():
    """Test full tile configuration validation"""
    print("\n\nTesting full tile configuration validation...")

    # Test case that was failing in the build
    tile_m, tile_n, tile_k = 256, 256, 32
    warp_m, warp_n, warp_k = 1, 4, 1
    warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8

    print("\nTesting problematic configuration:")
    print(f"  Tile: {tile_m}x{tile_n}x{tile_k}")
    print(f"  Warp: {warp_m}x{warp_n}x{warp_k}")
    print(f"  WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")

    valid = is_tile_config_valid(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        "fp16",
        "fp16",
        "fp16",
        "mem",
    )

    print(f"  Valid: {valid}")
    print("  Expected: False (warp tile [4, 64, 8] is not supported for fp16)")

    # Test a valid configuration
    warp_tile_k = 16  # Change to valid value
    print("\nTesting corrected configuration:")
    print(f"  WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")

    valid = is_tile_config_valid(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        "fp16",
        "fp16",
        "fp16",
        "mem",
    )

    print(f"  Valid: {valid}")
    print("  Expected: True")


def main():
    """Run all tests"""
    print("=" * 60)
    print("GEMM Validation Test Suite")
    print("=" * 60)

    test_warp_tile_validation()
    test_trait_combinations()
    test_full_tile_config_validation()

    print("\n" + "=" * 60)
    print("Test suite completed")
    print("=" * 60)


if __name__ == "__main__":
    main()
